# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional

from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class NeptuneLoggerHook(LoggerHook):
    """Class to log metrics to NeptuneAI.

    It requires `Neptune`_ to be installed.

    Args:
        init_kwargs (dict, optional): a dict containing the initialization
            keys as below. If None or empty, ``neptune.init()`` is called
            without arguments. Default: None.

            - project (str): Name of a project in a form of
              namespace/project_name. If None, the value of NEPTUNE_PROJECT
              environment variable will be taken.
            - api_token (str): User's API token. If None, the value of
              NEPTUNE_API_TOKEN environment variable will be taken. Note: It is
              strongly recommended to use NEPTUNE_API_TOKEN environment
              variable rather than placing your API token in plain text in your
              source code.
            - name (str, optional, default is 'Untitled'): Editable name of the
              run. Name is displayed in the run's Details and in Runs table as
              a column.

            Check https://docs.neptune.ai/api-reference/neptune#init for more
            init arguments.
        interval (int): Logging interval (every k iterations). Default: 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than ``interval``. Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: True.
        with_step (bool): If True, each value is logged with the step taken
            from ``self.get_iter(runner)``. Otherwise, values are logged
            without an explicit step. Default: True.
        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.

    .. _Neptune:
        https://docs.neptune.ai
    """

    def __init__(self,
                 init_kwargs: Optional[Dict] = None,
                 interval: int = 10,
                 ignore_last: bool = True,
                 reset_flag: bool = True,
                 with_step: bool = True,
                 by_epoch: bool = True):

        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        # Fail fast at construction time if neptune is not installed.
        self.import_neptune()
        self.init_kwargs = init_kwargs
        self.with_step = with_step

    def import_neptune(self) -> None:
        """Import ``neptune.new`` and reset the run handle.

        The import happens here rather than at module level so this module
        can be imported even when neptune is not installed.

        Raises:
            ImportError: If the ``neptune-client`` package is unavailable.
        """
        try:
            import neptune.new as neptune
        except ImportError as err:
            raise ImportError(
                'Please run "pip install neptune-client" to install neptune'
            ) from err
        self.neptune = neptune
        # The run object is created lazily in ``before_run``.
        self.run = None

    @master_only
    def before_run(self, runner) -> None:
        """Create the Neptune run, forwarding ``init_kwargs`` if provided."""
        # A None or empty ``init_kwargs`` results in a bare ``init()`` call.
        self.run = self.neptune.init(**(self.init_kwargs or {}))

    @master_only
    def log(self, runner) -> None:
        """Log every loggable tag of ``runner`` to the Neptune run.

        Each tag value is appended to the series ``self.run[tag_name]``. When
        ``with_step`` is True the current iteration is passed as ``step``;
        otherwise the value is logged without an explicit step.
        """
        tags = self.get_loggable_tags(runner)
        if not tags:
            return
        if self.with_step:
            # The iteration does not change while logging, so query it once.
            step = self.get_iter(runner)
            for tag_name, tag_value in tags.items():
                self.run[tag_name].log(tag_value, step=step)  # type: ignore
        else:
            # Do not mutate ``tags`` while iterating over it: inserting a key
            # here would raise "dictionary changed size during iteration".
            for tag_name, tag_value in tags.items():
                self.run[tag_name].log(tag_value)  # type: ignore

    @master_only
    def after_run(self, runner) -> None:
        """Stop the Neptune run, if one was created in ``before_run``."""
        if self.run is not None:
            self.run.stop()
